{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "398cf110",
   "metadata": {},
   "source": [
    "# GNN performance as a function of data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8ba5c1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stock imports\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from lee_et_al_2023.src import analysis\n",
    "from lee_et_al_2023.src import base\n",
    "from lee_et_al_2023.src import data_loaders\n",
    "base.set_visual_settings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa74a4d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/rick/code/ratatouille/qian_et_al_2023/src/data_loaders.py:77: FutureWarning: The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.\n",
      "  panel = humans.groupby('RedJade Code').mean().loc[mol_codes, base.MONELL_CLASS_LIST]\n"
     ]
    }
   ],
   "source": [
    "# Canonical way to load (most of) the data... add lines as needs to include other pieces\n",
    "models, humans, panel, subjects = data_loaders.get_clean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ed401ab",
   "metadata": {},
   "source": [
    "### Run the code to prepare the data to be visualized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3853a409",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "transpose_ott = analysis.fast_process(humans, models, axis=1)\n",
    "corr_df = transpose_ott.groupby('index').agg(np.nanmedian)\n",
    "corr_df['Test data counts'] = (panel > 0.7).sum(axis=0)\n",
    "training_class_counts = pd.read_csv(base.DATA_PATH / \"training_class_counts.csv\")\n",
    "training_class_counts['label'] = training_class_counts['label'].apply(\n",
    "    lambda l: {\"jasmin\": \"jasmine\"}.get(l, l).title())\n",
    "training_class_counts = training_class_counts.set_index('label')\n",
    "training_class_counts = training_class_counts.loc[base.MONELL_CLASS_LIST]\n",
    "corr_df['Training data counts'] = training_class_counts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dde87b3",
   "metadata": {},
   "source": [
    "### Name the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f48f5f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_name = '3B'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0caf3ff",
   "metadata": {},
   "source": [
    "### Run the code to make the figure and save the figure to friendly formats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39db34cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,8))\n",
    "plot = sns.scatterplot(data=corr_df.reset_index(), x='Training data counts', y='GNN',\n",
    "                size='Test data counts', sizes=(20, 400), color='#fc8d62')\n",
    "plot.legend(loc='lower right', title='Test data counts')\n",
    "plot.set(xscale=\"log\")\n",
    "plt.xlabel('Training Data Counts', fontsize=20)\n",
    "plt.xlim(10**1.5, 10**3.5)\n",
    "plt.ylabel('GNN Correlation with Panel Mean', fontsize=20)\n",
    "plt.xticks(10 ** np.linspace(1.5, 3.5, 5), fontsize=16)\n",
    "plt.yticks(np.linspace(-0.1, 0.5, 7), fontsize=16)\n",
    "for label in ('Camphoreous', 'Fishy', 'Cooling', 'Sulfurous', 'Roasted', 'Fruity', 'Floral', 'Sweet', 'Green',\n",
    "               'Ozone', 'Sharp', 'Waxy', 'Medicinal', 'Musty', 'Fermented', 'Garlic', 'Alcoholic', 'Musk', 'Meaty'):\n",
    "    row = corr_df.loc[label]\n",
    "    offset = (5, 5)\n",
    "    if label == 'Floral':\n",
    "        offset = (5, -15)\n",
    "    if label == 'Fishy':\n",
    "        offset = (5, -15)\n",
    "    if label == 'Camphoreous':\n",
    "        offset = (5, -5)\n",
    "    plt.annotate(label, (row['Training data counts'], row['GNN']), xytext=offset, textcoords='offset points',\n",
    "                fontsize=20)\n",
    "\n",
    "for axis in ['bottom','left']:\n",
    "    plot.spines[axis].set_linewidth(3)\n",
    "    plot.spines[axis].set_edgecolor('black')\n",
    "for axis in ['top','right']:\n",
    "    plot.spines[axis].set_visible(False)\n",
    "\n",
    "plot.grid(False)\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,py:light"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
